{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import preprocessor as p\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%run ../twitter16/twitter16_text_processing.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle as pkl\n",
    "from collections import defaultdict\n",
    "import pandas as pd\n",
    "import os\n",
    "import numpy as np\n",
    "import json\n",
    "from tqdm import tqdm, tqdm_notebook\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import classification_report, f1_score, accuracy_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "import spacy\n",
    "from tqdm import tqdm, tqdm_notebook, tnrange\n",
    "import pandas as pd\n",
    "from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report, confusion_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Node:\n",
    "    def __init__(self,uid,tid,time_stamp,label):\n",
    "        self.children = {}\n",
    "        self.childrenList = []\n",
    "        self.num_children = 0\n",
    "        self.tid = tid\n",
    "        self.uid = uid\n",
    "        self.label = label\n",
    "        self.time_stamp = time_stamp\n",
    "    \n",
    "    def add_child(self,node):\n",
    "        if node.uid not in self.children:\n",
    "            self.children[node.uid] = node\n",
    "            self.num_children += 1\n",
    "        else:\n",
    "            self.children[node.uid] = node\n",
    "        self.childrenList = list(self.children.values())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Tree:\n",
    "    def __init__(self,root):\n",
    "        self.root = root\n",
    "        self.tweet_id = root.tid\n",
    "        self.uid = root.uid\n",
    "        self.height = 0\n",
    "        self.nodes = 0\n",
    "    \n",
    "    def show(self):\n",
    "        queue = [self.root,0]\n",
    "        \n",
    "        while len(queue) != 0:\n",
    "            toprint = queue.pop(0)\n",
    "            if toprint == 0:\n",
    "                print('\\n')\n",
    "            else:\n",
    "                print(toprint.uid,end=' ')\n",
    "                queue += toprint.children.values()\n",
    "                queue.append(0)\n",
    "                \n",
    "    def insertnode(self,curnode,parent,child):\n",
    "        if curnode.uid == parent.uid:\n",
    "            curnode.add_child(child)\n",
    "            return 1\n",
    "\n",
    "        elif parent.uid in curnode.children:\n",
    "            s = self.insertnode(curnode.children[parent.uid],parent,child)\n",
    "            return 2\n",
    "        else:\n",
    "            for node in curnode.children:\n",
    "                s = self.insertnode(curnode.children[node],parent,child)\n",
    "                if s == 2:\n",
    "                    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loadPklFileNum(datapath,incSize,fileNum):\n",
    "    \n",
    "    with open(datapath+str(incSize)+'inc_'+str(fileNum)+'.pickle', 'rb') as handle:\n",
    "        twitTrees = pkl.load(handle)\n",
    "    return twitTrees"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loadTreeFilesOfIncrement(datapath,incSize):\n",
    "    twittertrees = {}\n",
    "    \n",
    "    files = [x for x in os.listdir(t15Datapath) if str(incSize)+'inc' in x]\n",
    "    \n",
    "    for file in tqdm(files):\n",
    "        with open(datapath+file,'rb') as handle:\n",
    "            partialTrees = pkl.load(handle)\n",
    "        twittertrees.update(partialTrees)\n",
    "        \n",
    "    return twittertrees"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t15Datapath = '../twitter16/pickledTrees/'\n",
    "# twitter15_trees = loadPklFileNum(t15Datapath,20,1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "twitter15_trees = loadTreeFilesOfIncrement(t15Datapath,20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%run ../twitter16/userdata_parser.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for key in tqdm(userVects):\n",
    "    userVects[key] = userVects[key].float()\n",
    "\n",
    "userVects = defaultdict(lambda:torch.tensor([1.1100e+02, 1.5000e+01, 0.0000e+00, 7.9700e+02, 4.7300e+02, 0.0000e+00,\n",
    "        8.3326e+04, 1.0000e+00]),userVects)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "%run ./textEncoders.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%run ./temporal_tree_model.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if torch.cuda.is_available():\n",
    "    device = 'cuda:1'\n",
    "    device = 'cpu'\n",
    "else:\n",
    "    device = 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "labelMap = {}\n",
    "labelCount = 0\n",
    "for label in list(twitter15_labels.values()):\n",
    "    if label not in labelMap:\n",
    "        labelMap[label] = labelCount\n",
    "        labelCount += 1\n",
    "labelMap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs = 10\n",
    "X = []\n",
    "y = []\n",
    "X_text = []\n",
    "\n",
    "for tid in twitter15_trees:\n",
    "        if tid in twitter15_trees and tid in twitter15_labels:\n",
    "            X.append(tuple((twitter15_trees[tid],twitter15_text[tid])))\n",
    "            y.append(labelMap[twitter15_labels[tid]])\n",
    "            X_text.append(twitter15_text[tid])\n",
    "            \n",
    "x_train,x_test,y_train,y_test = train_test_split(X,y,random_state=2018)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "textparams = {\n",
    "    'embedding_dim':256,\n",
    "    'hidden_dim': 50,\n",
    "    'output_dim':4,\n",
    "    'num_layers':1,\n",
    "    'bidir':True,\n",
    "    'rnnType':'gru'\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "criterion = torch.nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = temporalDecayTreeEncoder(torch.cuda.is_available(),8,100,userVects,twitter15_labels,labelMap,criterion,device)\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "f = lambda m, n: [(i*n//m + n//(2*m)) for i in range(m)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'torch' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-1-b65f4e3cfd1c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0moptimizer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptim\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mAdagrad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mlr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.01\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0mcount\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0mmaxAcc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'torch' is not defined"
     ]
    }
   ],
   "source": [
    "optimizer = torch.optim.Adagrad(model.parameters(),lr=0.01)\n",
    "\n",
    "count = 0\n",
    "maxAcc = 0\n",
    "\n",
    "train_iterwise = []\n",
    "val_iterwise = []\n",
    "\n",
    "for i in range(7):\n",
    "    train_losses = []\n",
    "    val_losses = []\n",
    "    \n",
    "    for treeSet,text in tqdm(x_train):\n",
    "        tnum = 0\n",
    "        \n",
    "        idxs = f(5,len(treeSet))\n",
    "#         for idx in idxs:\n",
    "        trees = [ treeSet[idx] for idx in idxs ]\n",
    "        count += 1\n",
    "        tnum += 1\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        pred = model(trees)\n",
    "        label = Variable(torch.tensor(labelMap[trees[0].root.label]))\n",
    "\n",
    "        loss = criterion(pred.reshape(-1,4), label.reshape(-1))    \n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "            \n",
    "    preds = []\n",
    "    labels = []\n",
    "\n",
    "    allLabels = []\n",
    "    allPreds = []\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for valSet,text in tqdm(x_test):\n",
    "            idxs = f(5,len(valSet))\n",
    "            trees = [ valSet[idx] for idx in idxs ]\n",
    "            \n",
    "            predicted = model(trees)\n",
    "            preds.append(predicted)\n",
    "\n",
    "            predicted =  torch.softmax(predicted[0],0)\n",
    "            predicted = torch.max(predicted, 0)[1].cpu().numpy().tolist()\n",
    "\n",
    "            labels.append(labelMap[trees[0].root.label])\n",
    "\n",
    "            allLabels.append(labelMap[trees[0].root.label])\n",
    "            allPreds.append(predicted)\n",
    "\n",
    "    predTensor = torch.stack(preds)\n",
    "    labelTensor = torch.tensor(labels).to(device)\n",
    "\n",
    "    print(allLabels,allPreds)\n",
    "\n",
    "    loss = criterion(predTensor.reshape(-1,4), labelTensor.reshape(-1))\n",
    "\n",
    "    cr = classification_report(allLabels,allPreds,output_dict=True)\n",
    "    cr['loss'] = loss.item()\n",
    "    cr['Acc'] = accuracy_score(allLabels,allPreds,)\n",
    "    print('loss: ',cr['loss'])\n",
    "    print(cr['Acc'])\n",
    "    \n",
    "    if cr['Acc'] > maxAcc:\n",
    "            maxAcc = cr['Acc']\n",
    "            torch.save({'state_dict': model.state_dict()}, './pretrainedModels-Twit16/'+'decay_tempTreeEnc_pretrained_inc5'+'.pth')\n",
    "    \n",
    "    with open('./pretrainedModels-Twit16/'+'decay_tempTreeEnc_pretrained_inc5'+'.json', 'a') as fp:\n",
    "            json.dump(cr, fp)\n",
    "            fp.write('\\n')\n",
    "        \n",
    "    val_losses.append(loss.item())\n",
    "    train_iterwise.append(np.array(train_losses).mean())\n",
    "    val_iterwise.append(np.array(val_losses).mean())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "rnnTypes = ['gru','lstm']\n",
    "bidirTypes = [True,False]\n",
    "\n",
    "for rnnType in rnnTypes:\n",
    "    for bidirType in bidirTypes:\n",
    "        textparams[bidir] = bidirType\n",
    "        textparams[rnnType] = rnnType\n",
    "        \n",
    "        textEncoderModel = TextEncoder('rnn',textparams,X_text,y,device)\n",
    "        textEncoderModel.trainModel()\n",
    "        print(len(textEncoderModel.word2idx))\n",
    "        \n",
    "        modelname = rnnType\n",
    "        if bidirType:\n",
    "            modelname = 'bidir'+modelname\n",
    "        \n",
    "        torch.save({'state_dict': textEncoderModel.optimalParams}, './pretrainedModels-Twit15/'+modelname+'.pth')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
